{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.keras import layers,activations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 残差块"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Residual(tf.keras.Model):\n",
    "    \"\"\"Residual block: two 3x3 convolutions plus an optional 1x1 shortcut convolution.\n",
    "\n",
    "    Args:\n",
    "        num_channels: Number of filters of every convolution in the block.\n",
    "        use_1x1conv: If True, the shortcut input is passed through a 1x1\n",
    "            convolution (with `strides`) before the addition; otherwise the\n",
    "            input is added unchanged.\n",
    "        strides: Stride of the first 3x3 convolution and of the optional 1x1\n",
    "            shortcut convolution.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_channels, use_1x1conv=False, strides=1):\n",
    "        super().__init__()\n",
    "        # Main path: a 3x3 convolution that applies `strides`, then a stride-1 3x3 convolution.\n",
    "        self.conv1 = layers.Conv2D(num_channels, kernel_size=3, strides=strides, padding='same')\n",
    "        self.conv2 = layers.Conv2D(num_channels, kernel_size=3, padding='same')\n",
    "        # Shortcut path: optional 1x1 convolution, or None for an identity shortcut.\n",
    "        if use_1x1conv:\n",
    "            self.conv3 = layers.Conv2D(num_channels, kernel_size=1, strides=strides)\n",
    "        else:\n",
    "            self.conv3 = None\n",
    "        # One batch-normalization layer after each main-path convolution.\n",
    "        self.bn1 = layers.BatchNormalization()\n",
    "        self.bn2 = layers.BatchNormalization()\n",
    "\n",
    "    def call(self, x):\n",
    "        \"\"\"Return relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x)).\"\"\"\n",
    "        Y = activations.relu(self.bn1(self.conv1(x)))\n",
    "        Y = self.bn2(self.conv2(Y))\n",
    "        # Test the optional layer against None explicitly rather than relying\n",
    "        # on the truthiness of a layer object.\n",
    "        if self.conv3 is not None:\n",
    "            x = self.conv3(x)\n",
    "        return activations.relu(Y + x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 残差模块"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResnetBlock(tf.keras.layers.Layer):\n",
    "    \"\"\"Sequence of `num_res` Residual blocks that all use `num_channels` filters.\n",
    "\n",
    "    Args:\n",
    "        num_channels: Number of filters passed to every Residual block.\n",
    "        num_res: Number of Residual blocks to create.\n",
    "        first_block: If False, the first Residual block is built with a 1x1\n",
    "            shortcut convolution and stride 2; if True, every Residual block\n",
    "            is built with the default arguments.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_channels, num_res, first_block=False):\n",
    "        super().__init__()\n",
    "        # Residual blocks, applied in order by `call`.\n",
    "        self.listLayers = []\n",
    "        for i in range(num_res):\n",
    "            # Only the first block of a non-first module gets the strided 1x1 shortcut.\n",
    "            if i == 0 and not first_block:\n",
    "                self.listLayers.append(Residual(num_channels, use_1x1conv=True, strides=2))\n",
    "            else:\n",
    "                self.listLayers.append(Residual(num_channels))\n",
    "\n",
    "    def call(self, X):\n",
    "        \"\"\"Feed `X` through every stored Residual block in order and return the result.\"\"\"\n",
    "        # Iterate the list itself. The previous `self.listLayers.layers` relied on\n",
    "        # an attribute that plain Python lists do not have and that only exists on\n",
    "        # the list wrapper some TensorFlow versions substitute for list attributes.\n",
    "        for layer in self.listLayers:\n",
    "            X = layer(X)\n",
    "        return X"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 构建resNet网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResNet(tf.keras.Model):\n",
    "    \"\"\"ResNet-style classifier: input stem, four residual modules, pooling, softmax head.\n",
    "\n",
    "    Args:\n",
    "        num_blocks: Indexable with at least four entries; entry i is the number\n",
    "            of Residual blocks in module i + 1 (for example [2, 2, 2, 2]).\n",
    "        num_classes: Number of units in the final softmax layer. Defaults to\n",
    "            10, the value that was previously hard-coded.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_blocks, num_classes=10):\n",
    "        super().__init__()\n",
    "        # Input stem: 7x7 stride-2 convolution, batch norm, ReLU, 3x3 stride-2 max pooling.\n",
    "        self.conv = layers.Conv2D(64, kernel_size=7, strides=2, padding='same')\n",
    "        self.bn = layers.BatchNormalization()\n",
    "        self.relu = layers.Activation('relu')\n",
    "        self.mp = layers.MaxPool2D(pool_size=3, strides=2, padding='same')\n",
    "        # Residual modules with 64, 128, 256 and 512 filters.\n",
    "        self.res_block1 = ResnetBlock(64, num_blocks[0], first_block=True)\n",
    "        self.res_block2 = ResnetBlock(128, num_blocks[1])\n",
    "        self.res_block3 = ResnetBlock(256, num_blocks[2])\n",
    "        self.res_block4 = ResnetBlock(512, num_blocks[3])\n",
    "        # Global average pooling followed by the fully connected softmax layer.\n",
    "        self.gap = layers.GlobalAvgPool2D()\n",
    "        self.fc = layers.Dense(\n",
    "            units=num_classes, activation=tf.keras.activations.softmax)\n",
    "\n",
    "    def call(self, x):\n",
    "        \"\"\"Run `x` through the stem, the four residual modules and the softmax head.\"\"\"\n",
    "        # Input stem.\n",
    "        x = self.conv(x)\n",
    "        x = self.bn(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.mp(x)\n",
    "        # Residual modules.\n",
    "        x = self.res_block1(x)\n",
    "        x = self.res_block2(x)\n",
    "        x = self.res_block3(x)\n",
    "        x = self.res_block4(x)\n",
    "        # Output head.\n",
    "        x = self.gap(x)\n",
    "        x = self.fc(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"res_net_5\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv2d_62 (Conv2D)           multiple                  3200      \n",
      "_________________________________________________________________\n",
      "batch_normalization_53 (Batc multiple                  256       \n",
      "_________________________________________________________________\n",
      "activation_5 (Activation)    multiple                  0         \n",
      "_________________________________________________________________\n",
      "max_pooling2d_5 (MaxPooling2 multiple                  0         \n",
      "_________________________________________________________________\n",
      "resnet_block_14 (ResnetBlock multiple                  148736    \n",
      "_________________________________________________________________\n",
      "resnet_block_15 (ResnetBlock multiple                  526976    \n",
      "_________________________________________________________________\n",
      "resnet_block_16 (ResnetBlock multiple                  2102528   \n",
      "_________________________________________________________________\n",
      "resnet_block_17 (ResnetBlock multiple                  8399360   \n",
      "_________________________________________________________________\n",
      "global_average_pooling2d_3 ( multiple                  0         \n",
      "_________________________________________________________________\n",
      "dense_3 (Dense)              multiple                  5130      \n",
      "=================================================================\n",
      "Total params: 11,186,186\n",
      "Trainable params: 11,178,378\n",
      "Non-trainable params: 7,808\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Instantiate a ResNet with two Residual blocks in each of the four modules.\n",
    "mynet = ResNet([2, 2, 2, 2])\n",
    "# Call the model once on a random 1x224x224x1 input, then print its summary.\n",
    "X = tf.random.uniform(shape=(1, 224, 224, 1))\n",
    "y = mynet(X)\n",
    "mynet.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 数据读取"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from tensorflow.keras.datasets import mnist\n",
    "# Load the MNIST handwritten-digit dataset.\n",
    "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
    "# Append a trailing channel axis of size 1 so both image sets are laid out as N, H, W, C.\n",
    "train_images = train_images.reshape(train_images.shape + (1,))\n",
    "test_images = test_images.reshape(test_images.shape + (1,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helpers that randomly draw a subset of samples for a quick demonstration.\n",
    "def _sample_and_resize(images, labels, size):\n",
    "    \"\"\"Draw `size` random samples and pad-resize their images to 224x224.\n",
    "\n",
    "    Args:\n",
    "        images: Image array, indexed along its first axis with an index array.\n",
    "        labels: Label array, indexed with the same index array as `images`.\n",
    "        size: Number of random indices to draw.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(resized_images, sampled_labels)`, where `resized_images` is the\n",
    "        NumPy value of `tf.image.resize_with_pad(images[index], 224, 224)`.\n",
    "    \"\"\"\n",
    "    # Random indices in [0, number of images); indices may repeat.\n",
    "    index = np.random.randint(0, np.shape(images)[0], size)\n",
    "    # Resize the selected images to 224x224.\n",
    "    resized_images = tf.image.resize_with_pad(images[index], 224, 224)\n",
    "    return resized_images.numpy(), labels[index]\n",
    "\n",
    "\n",
    "def get_train(size):\n",
    "    \"\"\"Return `size` randomly drawn training images (resized to 224x224) and their labels.\"\"\"\n",
    "    return _sample_and_resize(train_images, train_labels, size)\n",
    "\n",
    "\n",
    "def get_test(size):\n",
    "    \"\"\"Return `size` randomly drawn test images (resized to 224x224) and their labels.\"\"\"\n",
    "    return _sample_and_resize(test_images, test_labels, size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the training and test subsets used by the following cells.\n",
    "# NOTE: this rebinds train_images/test_images, which get_train/get_test read,\n",
    "# so re-running this cell samples from the already-reduced arrays.\n",
    "train_images, train_labels = get_train(size=256)\n",
    "test_images, test_labels = get_test(size=128)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 模型编译"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure training: SGD optimizer, sparse categorical cross-entropy loss, accuracy metric.\n",
    "optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.0)\n",
    "\n",
    "mynet.compile(\n",
    "    optimizer=optimizer,\n",
    "    loss='sparse_categorical_crossentropy',\n",
    "    metrics=['accuracy'],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 模型训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/3\n",
      "2/2 [==============================] - 8s 4s/step - loss: 2.7646 - accuracy: 0.1348 - val_loss: 4.4041 - val_accuracy: 0.0769\n",
      "Epoch 2/3\n",
      "2/2 [==============================] - 7s 4s/step - loss: 2.1875 - accuracy: 0.2826 - val_loss: 5.0271 - val_accuracy: 0.0769\n",
      "Epoch 3/3\n",
      "2/2 [==============================] - 7s 4s/step - loss: 1.9738 - accuracy: 0.3435 - val_loss: 3.6854 - val_accuracy: 0.3077\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x13e4ef438>"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train for 3 epochs with batch size 128, using validation_split=0.1 for validation.\n",
    "mynet.fit(\n",
    "    train_images,\n",
    "    train_labels,\n",
    "    batch_size=128,\n",
    "    epochs=3,\n",
    "    verbose=1,\n",
    "    validation_split=0.1,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 模型评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4/4 [==============================] - 1s 342ms/step - loss: 4.9585 - accuracy: 0.1094\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[4.958542823791504, 0.109375]"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate the trained model on the sampled test images and labels.\n",
    "mynet.evaluate(x=test_images, y=test_labels, verbose=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
